import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from models.spikingcnn import ConvNet, ConvNetDynamic
from functions.erf import compute_spatial_erf, compute_temporal_erf, compute_temporal_erf_2

import os
# Default to GPU 5, but respect a CUDA_VISIBLE_DEVICES already chosen by the
# shell/launcher instead of silently overriding it. This has to run before
# anything initializes CUDA for the new value to take effect.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '5')

def visualize_erf(models_config, input_size=14, num_runs=50, nrows=6):
    """Visualize the spatial ERF for several ``ConvNet`` configurations.

    One ``ConvNet`` is built per config and its ERF is computed with
    ``compute_spatial_erf``. Each ERF is drawn as a grayscale image in a grid
    of subplots, filled row by row in config order.

    Args:
        models_config: Sequence of dicts, each with keys ``'layers'``,
            ``'kernel_size'``, ``'weight_type'`` and ``'activation'``.
        input_size: Input size forwarded to ``compute_spatial_erf``.
        num_runs: Number of runs forwarded to ``compute_spatial_erf``.
        nrows: Maximum number of subplot rows. The column count is derived
            so that every config gets its own axis. The default of 6
            reproduces the previous ``6 x N/6`` layout when N is a multiple
            of 6.

    Returns:
        The matplotlib ``Figure`` containing the grid.

    Raises:
        ValueError: If ``models_config`` is empty or ``nrows`` < 1.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_configs = len(models_config)
    if num_configs == 0:
        raise ValueError("models_config must contain at least one configuration")
    if nrows < 1:
        raise ValueError(f"nrows must be >= 1, got {nrows}")

    # Size the grid so every config gets an axis. The old `num_configs // 6`
    # column count dropped the trailing configs (IndexError below) or became
    # 0 for fewer than 6 configs.
    nrows = min(nrows, num_configs)
    ncols = (num_configs + nrows - 1) // nrows  # ceiling division
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so
    # .flatten() is always valid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 8), squeeze=False)
    axes = axes.flatten()

    for idx, config in enumerate(models_config):
        model = ConvNet(
            num_layers=config['layers'],
            kernel_size=config['kernel_size'],
            weight_type=config['weight_type'],
            activation=config['activation']
        ).to(device)
        print(model)

        # Compute ERF
        erf = compute_spatial_erf(model, input_size, num_runs)

        # Plot
        ax = axes[idx]
        ax.imshow(erf, cmap='gray')
        ax.set_title(f"{config['layers']} layers\n{config['weight_type']}\n{config['activation']}")
        ax.axis('off')

    # Blank out grid cells that have no config (only when N doesn't fill the grid).
    for ax in axes[num_configs:]:
        ax.axis('off')

    plt.tight_layout()
    return fig

def visualize_temporal_erf(models_config, input_size=32, num_runs=20, nrows=7):
    """
    Visualize temporal ERF for different LIF configurations.

    One ``ConvNetDynamic`` is built per config and its temporal ERF is
    computed with ``compute_temporal_erf_2``. The averaged gradient is
    plotted against the time-step index, one subplot per config, with the
    grid filled row by row in config order.

    Args:
        models_config: List of dicts, each with keys ``'layers'``,
            ``'kernel_size'``, ``'weight_type'``, ``'tau'``, ``'Vth'``,
            ``'surrogate_mode'`` and ``'alpha'``.
        input_size: Size of the spatial input dimension (H=W), forwarded to
            ``compute_temporal_erf_2``.
        num_runs: Number of random runs for averaging, forwarded to
            ``compute_temporal_erf_2``.
        nrows: Maximum number of subplot rows. The column count is derived
            so that every config gets its own axis. The default of 7
            reproduces the previous ``7 x N/7`` layout when N is a multiple
            of 7.

    Returns:
        The matplotlib ``Figure`` containing the grid.

    Raises:
        ValueError: If ``models_config`` is empty or ``nrows`` < 1.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_configs = len(models_config)
    if num_configs == 0:
        raise ValueError("models_config must contain at least one configuration")
    if nrows < 1:
        raise ValueError(f"nrows must be >= 1, got {nrows}")

    # Create subplot grid large enough for every config. The old
    # `num_configs // 7` column count dropped trailing configs, or became 0
    # for fewer than 7 configs.
    nrows = min(nrows, num_configs)
    ncols = (num_configs + nrows - 1) // nrows  # ceiling division
    # squeeze=False always yields a 2-D array of Axes, replacing the old
    # (unreachable) `num_configs <= 2` special case.
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 8), squeeze=False)
    axes = axes.flatten()

    for idx, config in enumerate(models_config):
        # Create model with current configuration
        model = ConvNetDynamic(
            num_layers=config['layers'],
            kernel_size=config['kernel_size'],
            weight_type=config['weight_type'],
            tau=config['tau'],
            Vth=config['Vth'],
            surrogate_mode=config['surrogate_mode'],
            alpha=config['alpha'],
        ).to(device)

        print(model)

        # Compute temporal ERF. NOTE(review): presumably indexed by time step
        # along dim 0 (earlier note said [T, (H*W)]); confirm against
        # compute_temporal_erf_2. If it is 2-D, ax.plot draws one line per column.
        avg_temporal_grad = compute_temporal_erf_2(model, input_size, num_runs)

        # Plot
        ax = axes[idx]
        time_steps = np.arange(len(avg_temporal_grad))
        ax.plot(time_steps, avg_temporal_grad, '-o', linewidth=2, markersize=4)

        ax.set_xlabel('Time Steps')
        ax.set_ylabel('Average Gradient')
        ax.set_title(f"{config['surrogate_mode']}\n tau={config['tau']}")
        ax.grid(True, linestyle='--', alpha=0.7)

    # Blank out grid cells that have no config (only when N doesn't fill the grid).
    for ax in axes[num_configs:]:
        ax.axis('off')

    plt.tight_layout()
    return fig



if __name__ == "__main__":
    # Static (ConvNet) sweep: for every depth, the same 7 weight/activation
    # variants. The list is built depth-major (identical order to the former
    # hand-written list), so visualize_erf's row-major 6-row grid shows one
    # depth per row.
    static_variants = [
        ('uniform', 'none'),
        ('random', 'none'),
        ('uniform', 'relu'),
        ('uniform', 'tanh'),
        ('uniform', 'sigmoid'),
        ('uniform', 'MultispikeNorm4'),
        ('uniform', 'Multispike4'),
    ]
    static_models_config = [
        {'layers': layers, 'kernel_size': 3, 'weight_type': weight_type, 'activation': activation}
        for layers in (5, 10, 20, 40, 100, 200)
        for weight_type, activation in static_variants
    ]

    # Dynamic (ConvNetDynamic) sweep: 4 surrogate modes per tau, tau-major
    # (same order as the former hand-written list). These are the keys
    # visualize_temporal_erf reads; it is not invoked below. To plot it, call
    # visualize_temporal_erf(dynamic_models_config, ...) and save the result.
    dynamic_models_config = [
        {'layers': 10, 'kernel_size': 3, 'weight_type': 'random', 'tau': tau,
         'Vth': 1.0, 'surrogate_mode': surrogate_mode, 'alpha': 2.0}
        for tau in (1.0, 1.1, 2.0, 4.0, 8.0, 16.0, 32.0)
        for surrogate_mode in ('triangle', 'sigmoid', 'arctan', 'rectangle')
    ]

    # visualize_erf creates its own figure via plt.subplots. A preceding
    # plt.figure() call only produced an extra blank window, so none is made.
    fig = visualize_erf(
        models_config=static_models_config,
        input_size=48,
        num_runs=50)

    # Save before the blocking plt.show(), so the PDF is written even if the
    # window/session is closed abruptly.
    fig.savefig('visualization_gm_2.pdf')
    plt.show()
    
    
